Related GANs and their SRGAN ablation experiments
本文将从三部分,即 GAN 模型的理论部分,代码(实践)部分及 SRGAN 的消融试验部分展开介绍
1. GAN(Generative Adversarial Network)生成对抗网络
核心:由两个神经网络——生成器(Generator)和判别器(Discriminator)组成,通过博弈过程相互提升。 · 生成器:试图“伪造”以假乱真的数据。 · 判别器:判断输入是真实数据还是生成器伪造的。 · 训练目标:生成器希望骗过判别器,判别器希望准确识别真假。 本质上是一个最大最小问题:
2. cGAN(Conditional GAN)条件生成对抗网络
核心:在 GAN 的基础上,引入“条件”信息(如标签、图像、文本等) · 生成器和判别器都接收条件变量 · G(z,y):在条件 y 下生成图像 · D(x,y):判断图像是否为在条件 y 下真实的 用途:图像翻译(如黑白图像上色)、语义图生成图像、文本生成图像 目标函数:
3. SRGAN
目的:图像超分辨率,即将低分辨率图像(LR)还原成高分辨率图像(HR) · 生成器结构:使用残差网络(ResNet)进行细节重建。 · 判别器:区分生成的高分图像和真实高分图像。 损失函数包含: · 内容损失(如 MSE 或感知损失); · 对抗损失(判别器输出) · 感知损失(Perceptual Loss):在 VGG 网络的高层 feature 上计算差异,更贴近人类视觉感受
4. ESRGAN
基于 SRGAN,具有如下优势:
- Residual-in-Residual Dense Block (RRDB):替换原 SRGAN 的残差块,结构更深,信息流更丰富。
- 对抗损失改进:采用 Relativistic average GAN(RaGAN),即判断“生成图是否比真实图更假”,而不是简单判断真假。
- 感知损失优化:使用未归一化的 VGG 特征图,避免图像过光滑。
- 训练技巧:使用多阶段训练,包括先训练内容损失,再加入对抗训练
总结(理论部分)
名称 | 全称 | 类型 | 特点概述 |
---|---|---|---|
GAN | Generative Adversarial Net | 无监督生成 | 对抗生成图像 |
cGAN | Conditional GAN | 条件生成 | 加入标签或条件进行控制 |
SRGAN | Super-Resolution GAN | 图像超分辨率 | 使用感知损失,生成自然高分图像 |
ESRGAN | Enhanced SRGAN | 图像超分辨率 | 加强网络结构和损失函数,细节更佳 |
在后文的试验中,原始代码与数据集皆存放在 GitHub 仓库:https://github.com/zqqqqqqj1110/GAN_WB
GAN 对抗神经网络及其变种(试验部分)
1. cGAN
本文以 cGAN 作为 baseline,后续的 gan 变种模型皆由该部分代码变换而来,因此在这部分会讲的较为全面一些,后文可能会省略
1.1 安装需要的包与环境
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import matplotlib.pyplot as plt
from pytorch_msssim import ms_ssim
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
作者使用的是 MacOS,因此使用了 mps,如果是 cuda 的话直接换成“cuda”即可,配置环境部分(gpu 环境)可转到我的个人博客处查阅(或者在这更,懒了,小鸽一下^_^)
1.2 数据预处理
首先需要加载原始的数据集,接着在低分辨率下生成数据(下采样)与保存(通过双线性插值的方法),最后计算 mean,std 等标准指标(都是后面需要用到的,为了计算指标,不如手动自己计算一下)
# === 1. 加载原始 HR 数据 ===
hr_train = np.load("seasonal_split/HR_data_train_tm_Summer.npy")[:200]
hr_valid = np.load("seasonal_split/HR_data_valid_tm_Summer.npy")[:200]
hr_test = np.load("seasonal_split/HR_data_test_tm_Summer.npy")[:200]
# === 2. 生成 LR 数据(双线性插值至 16×16) ===
def downsample(hr_array, scale=4):
tensor = torch.tensor(hr_array, dtype=torch.float32)
return F.interpolate(tensor, scale_factor=1/scale, mode="bilinear", align_corners=False).numpy()
lr_train = downsample(hr_train)
lr_valid = downsample(hr_valid)
lr_test = downsample(hr_test)
# === 3. 保存为 .npy 文件 ===
np.save("tm/HR_data_train_40.npy", hr_train)
np.save("tm/LR_data_train_40.npy", lr_train)
np.save("tm/HR_data_valid_40.npy", hr_valid)
np.save("tm/LR_data_valid_40.npy", lr_valid)
np.save("tm/HR_data_test_40.npy", hr_test)
np.save("tm/LR_data_test_40.npy", lr_test)
# === 4. mean和std===
mean = np.mean(hr_train, axis=(0, 2, 3))[:, None, None]
std = np.std(hr_train, axis=(0, 2, 3))[:, None, None]
np.save("tm/mean_40.npy", mean)
np.save("tm/std_40.npy", std)
# 每个通道的 min 和 max(例如 2 个通道)
hr_min = hr_train.min(axis=(0, 2, 3)) # shape: (2,)
hr_max = hr_train.max(axis=(0, 2, 3))
# 保存为 .npy 文件,后续评估使用
np.save("tm/min_40.npy", hr_min.astype(np.float32))
np.save("tm/max_40.npy", hr_max.astype(np.float32))
print("完成:生成 LR/HR 切片、保存归一化参数,包括 test")